# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
import math
import numpy as np
from utils.multilevel_rois import map_rois_to_fpn_levels


# from core.config import cfg
# from datasets import json_dataset
# import modeling.FPN as fpn
# import roi_data.fast_rcnn
# import utils.blob as blob_utils


class CollectAndDistributeFpnRpnProposals(torch.nn.Module):
    """Merge RPN proposals from every FPN level, keep the top-scoring
    subset, and redistribute the survivors over the FPN levels used for
    RoI pooling (PyTorch port of the Detectron op of the same name).

    Args:
        spatial_scales: per-level feature-map scales (e.g. 1/4 ... 1/64),
            ordered from finest to coarsest; each scale 1/2^k maps to
            FPN level k.
        train: if True, use the (larger) training-time post-NMS budget
            when collecting proposals.
    """

    def __init__(self, spatial_scales, train=False):
        super(CollectAndDistributeFpnRpnProposals, self).__init__()
        self._train = train
        # A spatial scale of 1/2^k corresponds to FPN pyramid level k.
        levels = [int(math.log2(1 / scale)) for scale in spatial_scales]
        self.rpn_levels = levels
        self.rpn_min_level = levels[0]
        self.rpn_max_level = levels[-1]

    def forward(self, roi_list, roi_score_list):
        """Collect the top proposals across levels, then distribute them.

        Args:
            roi_list: per-level roi tensors (rpn_rois_fpn2..fpn6).
            roi_score_list: matching per-level objectness score tensors.

        Returns:
            Whatever ``distribute`` returns: the per-level roi tensors
            plus the index permutation that restores the original order.
        """
        top_rois = collect(roi_list, roi_score_list, self._train)
        # NOTE(review): the original Detectron training-time path — which
        # rebuilt roidb entries from the RPN proposals via
        # json_dataset.add_proposals / roi_data.fast_rcnn — has NOT been
        # ported to PyTorch; only the inference-style distribute step runs
        # here regardless of self._train.
        return distribute(top_rois, self.rpn_min_level, self.rpn_max_level)


def collect(roi_inputs, score_inputs, train):
    """Concatenate RPN proposals from all FPN levels and keep the
    top-scoring ``post_nms_topN`` of them.

    Args:
        roi_inputs: list of per-level roi tensors, each of shape (N_i, 5)
            in [batch_idx, x0, y0, x1, y1] format.
        score_inputs: list of per-level objectness score tensors whose
            concatenation squeezes to shape (sum N_i,).
        train: bool; selects the post-NMS top-N budget.

    Returns:
        Tensor of shape (K, 5) holding the K <= post_nms_topN rois with
        the highest scores, ordered by descending score.
    """
    # Hard-coded RPN_POST_NMS_TOP_N from the Detectron TRAIN/TEST configs.
    post_nms_topN = 2000 if train else 1000

    rois = torch.cat(tuple(roi_inputs), 0)
    scores = torch.cat(tuple(score_inputs), 0).squeeze()

    # topk (descending by default) replaces the old negate-and-full-sort:
    # it avoids sorting every proposal when only the best top-N are kept,
    # and the min() guard also handles having fewer proposals than the
    # budget (topk would raise where slicing silently truncated).
    k = min(post_nms_topN, scores.numel())
    _, inds = torch.topk(scores, k)
    return rois[inds, :]


def distribute(rois, lvl_min, lvl_max):  # , label_blobs, outputs, train):
    """Distribute rois over the FPN levels lvl_min..lvl_max (inclusive).

    To understand the output blob order see the return value of
    roi_data.fast_rcnn.get_fast_rcnn_blob_names(is_training=False)
    in the original Detectron code.

    Args:
        rois: (N, 5) roi tensor in [batch_idx, x0, y0, x1, y1] format.
        lvl_min: lowest FPN level used for RoI pooling (Detectron's
            cfg.FPN.ROI_MIN_LEVEL, e.g. 2).
        lvl_max: highest FPN level used for RoI pooling (Detectron's
            cfg.FPN.ROI_MAX_LEVEL, e.g. 5).

    Returns:
        (distr_rois, rois_idx_restore): distr_rois is a list with one
        roi tensor per level in [lvl_min, lvl_max]; rois_idx_restore is
        the argsort permutation that restores the original roi order
        once the per-level tensors are concatenated back together.
    """
    # Level assignment happens on CPU/numpy; the roi tensors themselves
    # stay on their original device.
    lvls = map_rois_to_fpn_levels(rois.data.cpu().numpy(), lvl_min, lvl_max)

    distr_rois = []
    idx_per_level = []
    for lvl in range(lvl_min, lvl_max + 1):
        idx_lvl = np.where(lvls == lvl)[0]
        distr_rois.append(rois[idx_lvl, :])
        idx_per_level.append(idx_lvl)
    # Concatenate once instead of growing an array inside the loop
    # (the old code re-allocated rois_idx_order on every iteration).
    if idx_per_level:
        rois_idx_order = np.concatenate(idx_per_level)
    else:
        rois_idx_order = np.empty((0,))
    rois_idx_restore = np.argsort(rois_idx_order)
    return distr_rois, rois_idx_restore
